'''
Author: wjh
Date: 2022-12-18 01:32:38
FilePath: \scripts\main.py
'''

from jiayan import load_lm
from jiayan import CharHMMTokenizer
from subword_nmt.apply_bpe import BPE
import codecs
import jieba
from fairseq.models.transformer import TransformerModel
from clf.classifier import Classifier
import torch
from transformers import AutoTokenizer
import logging
jieba.setLogLevel(logging.INFO)

def classifier(model, txt, tokenizer_clf):
    """Classify the translation direction of *txt*.

    Args:
        model: binary classifier; called with an ``input_ids`` tensor and
            expected to return a tensor indexable as ``output[0, 0]`` /
            ``output[0, 1]`` (two class scores for one example).
        txt: the input sentence to classify.
        tokenizer_clf: HuggingFace-style tokenizer; called with
            ``return_tensors="pt"`` and must yield an ``input_ids`` entry.

    Returns:
        int: 0 if the first score is strictly larger, else 1
        (ties resolve to 1, matching the original comparison).
    """
    token = tokenizer_clf(txt, return_tensors="pt")["input_ids"]
    # Pure inference: disable autograd so no graph is built and no
    # gradient memory is retained for this forward pass.
    with torch.no_grad():
        output = model(token)
    return 0 if output[0, 0] > output[0, 1] else 1
        
if __name__ == "__main__":

    flag = 0  # 0: an->zh (Classical -> Modern Chinese), 1: zh->an
    # Direction classifier (BERT-based) decides which model to use.
    model = Classifier(freeze_bert=True)
    model.load_state_dict(torch.load('./model/clf_model.pth'), strict=True)
    tokenizer_clf = AutoTokenizer.from_pretrained("bert-base-chinese")
    # One fairseq Transformer per translation direction.
    model_an_zh = TransformerModel.from_pretrained(".", checkpoint_file="./model/an-zh.pt", data_name_or_path="./data-bin")
    model_zh_an = TransformerModel.from_pretrained(".", checkpoint_file="./model/zh-an.pt", data_name_or_path="./data-bin_zh-an")

    # Language model backing the Classical-Chinese word segmenter.
    lm = load_lm('./jiayan_models/jiayan.klm')
    # Hoisted out of the loop: the segmenter only depends on `lm`,
    # so rebuilding it per input was pure waste.
    hmm_tokenizer = CharHMMTokenizer(lm)

    while True:
        print("-----------------------------------------------------------------")
        src = input(">>")  # read one sentence from the user

        flag = classifier(model, src, tokenizer_clf)  # 0 = an->zh, 1 = zh->an
        print(flag)
        # Word segmentation + choice of the matching BPE codes file.
        if not flag:
            segmented = " ".join(hmm_tokenizer.tokenize(src))
            codes_path = "./file/bpecode.an"
        else:
            segmented = " ".join(jieba.cut(src))
            codes_path = "./file/bpecode.zh"

        # Apply BPE subword segmentation; `with` guarantees the codes
        # file is closed even if BPE construction/processing raises.
        with codecs.open(codes_path, encoding='utf-8') as codes:
            bpe = BPE(codes)
            res = bpe.process_line(segmented)

        # Translate with the model for the detected direction.
        if not flag:
            aftermodel = model_an_zh.translate(res)
        else:
            aftermodel = model_zh_an.translate(res)
        # Undo BPE: strip the "@@" continuation markers from each piece
        # and re-join without spaces (Chinese text has no word spaces).
        pieces = [piece.strip("@") for piece in aftermodel.split(" ")]
        output = "".join(pieces)
        print(output)